import logging
from openai import OpenAI
import tqdm
from datasets import Dataset, load_dataset, concatenate_datasets
import os
import argparse
import json
import random
import time

def _format_lettered_choices(options):
    """Render options as ``"A: x B: y ..."``, skipping entries that are None."""
    return " ".join(
        f"{letter}: {text}"
        for letter, text in zip("ABCD", options)
        if text is not None
    )


def _build_instruction_and_answer(name, example):
    """Return ``(instruction, answer)`` for one raw example of dataset ``name``.

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names.
    """
    if name in ('HiToM', 'HiToM_Third', 'HiToM_Fourth', 'ToMI'):
        # ToMI keeps its candidate answers under "containers"; Hi-ToM uses "choices".
        choices_text = example["containers"] if name == 'ToMI' else example["choices"]
        if isinstance(choices_text, list):
            choices_text = ", ".join(choices_text)
        instruction = f"Story: {example['story']} Question: {example['question']} Choices: {choices_text}"
        return instruction, example['answer']

    if name == 'ExploreToM':
        instruction = f"Story: {example['story_structure']} Question: {example['question']}"
        return instruction, example['expected_answer']

    if name == 'ToMbench':
        options = [example["OPTION-A"], example["OPTION-B"], example["OPTION-C"], example["OPTION-D"]]
        formatted = _format_lettered_choices(options)
        instruction = f"Story: {example['STORY']} Question: {example['QUESTION']} Choices: {formatted}"
        # Any answer letter other than A/B/C falls back to option D.
        by_letter = {'A': options[0], 'B': options[1], 'C': options[2]}
        return instruction, by_letter.get(example["答案\nANSWER"], options[3])

    if name == 'Socialqa':
        options = [example["answerA"], example["answerB"], example["answerC"]]
        formatted = _format_lettered_choices(options)
        instruction = f"Context: {example['context']} Question: {example['question']} Choices: {formatted}"
        # Any label other than A/B falls back to answer C.
        by_letter = {'A': options[0], 'B': options[1]}
        return instruction, by_letter.get(example["label_letter"], options[2])

    if name in ('SimpleToM_mental', 'SimpleToM_behavior', 'SimpleToM_judgment'):
        texts = example["choices"]["text"]
        formatted = _format_lettered_choices([texts[0], texts[1]])
        instruction = f"story: {example['story']} question: {example['question']} Choices: {formatted}"
        # Any answer key other than 'A' selects the second choice.
        return instruction, texts[0] if example["answerKey"] == 'A' else texts[1]

    if name in ('ToMATO_first', 'ToMATO_second'):
        formatted = _format_lettered_choices([example["a0"], example["a1"], example["a2"], example["a3"]])
        instruction = f"Conversation: {example['conversation']} Question: {example['q']} Choices: {formatted}"
        return instruction, example["a_str"]

    if name == 'OpenToM_attitude':
        instruction = f"story: {example['narrative']} Question: {example['question']} Choices: positive, negative, neutral"
        return instruction, example["answer"]

    if name in ('OpenToM_location_cg_fo', 'OpenToM_location_cg_so'):
        instruction = f"story: {example['narrative']} Question: {example['question']} Choices: Yes, No"
        return instruction, example["answer"]

    raise ValueError(f"Unsupported dataset name: {name!r}")


def make_map_fn(name, split, apply_chat_template):
    """Build a ``Dataset.map`` callback that converts raw examples to a common schema.

    Args:
        name: Dataset identifier; selects which fields of the example are read
            and is stored as ``data_source`` in the output.
        split: Split label stored under ``extra_info['split']``.
        apply_chat_template: When true, the prompt is a two-message chat list
            (system, user). Otherwise the prompt comes from
            ``generate_base_prompt`` using the module-level ``args``.

    Returns:
        A function ``process_fn(example, idx)`` returning a dict with the keys
        ``data_source``, ``prompt``, ``ability``, ``reward_model`` and
        ``extra_info``.

    Raises:
        ValueError: from the returned function, when ``name`` is unsupported.
    """
    def process_fn(example, idx):
        # Always derive the ground truth, so it is available on both prompt paths.
        entire_instruction, answer = _build_instruction_and_answer(name, example)

        if not apply_chat_template:
            prompt = generate_base_prompt(example, template_type=args.template_type)
        else:
            prompt = [
                # NOTE(review): the system message carries no "content"; kept
                # as-is so the produced dataset schema does not change. Confirm
                # that downstream consumers tolerate it.
                {
                    "role": "system"
                },
                {
                    "role": "user",
                    "content": entire_instruction,
                }
            ]

        data = {
            "data_source": name,
            "prompt": prompt,
            "ability": "logic",
            "reward_model": {
                "style": "rule",
                "ground_truth": answer
            },
            "extra_info": {
                'split': split,
                'index': idx,
                'apply_chat_template': apply_chat_template
            }
        }
        return data
    return process_fn

def Evaluate(prompt:str, model_name, max_retries=30):
    """Send ``prompt`` to a chat model and return the text of its reply.

    The request uses a fixed system message and ``temperature = 0``. Failed
    attempts are retried with capped exponential backoff.

    Args:
        prompt: User message sent to the model.
        model_name: Model identifier passed to the chat completions endpoint.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The content of the first choice's message.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        Exception: the error of the last attempt, when every attempt fails.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    # An uncapped 2 ** i with 30 retries would sleep for years; bound each wait.
    max_backoff_seconds = 60

    # The client is stateless across attempts, so build it once.
    client = OpenAI(
        api_key="",
        base_url="",
    )

    for i in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': prompt}
                ],
                temperature = 0
            )
            output = completion.choices[0].message.content
            print(output)
            return output
        except Exception as e:
            if i == max_retries - 1:  # If this was the last attempt
                raise  # re-throw the last exception
            logging.warning(f"Evaluate attempt {i + 1}/{max_retries} failed: {e}")
            # Capped exponential backoff plus up to one second of random jitter.
            sleep_time = min(2 ** i, max_backoff_seconds) + random.random()
            time.sleep(sleep_time)

def getOutput_OpenAI(prompt:str, model_name, max_retries=30):
    """Query an OpenAI chat model with ``prompt`` as the only (user) message.

    Args:
        prompt: User message sent to the model.
        model_name: Model identifier passed to the chat completions endpoint.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        A tuple ``(output, tokens)``: the stripped reply text and the
        ``usage`` object of the response.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        Exception: the error of the last attempt, when every attempt fails.
            Previously the function fell through and returned None, which made
            the caller's tuple unpacking fail with an unrelated TypeError.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    client = OpenAI(
    api_key="",
    )

    m = [
        #{"role": "system", "content": "Read the following social event related to you and answer the questions."},
        {'role': 'user', 'content': prompt},
    ]

    for i in range(max_retries):
        try:
            res = client.chat.completions.create(
                model=model_name,
                messages=m
            )

            output = res.choices[0].message.content.strip()
            tokens = res.usage

            return output, tokens
        except Exception as e:
            print(e)
            if i == max_retries - 1:
                # Out of attempts: surface the real failure to the caller.
                raise
            # Fixed delay between attempts.
            time.sleep(15)


def getOutput_DeepSeek(prompt:str, max_retries=30):
    """Stream a ``deepseek-r1`` completion for ``prompt`` and echo it to stdout.

    Reasoning chunks and answer chunks are printed as they arrive and
    accumulated separately. A failed attempt discards partial output and is
    retried with capped exponential backoff.

    Args:
        prompt: User message sent to the model.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        A tuple ``(prediction, tokens)`` where ``prediction`` is the reasoning
        text followed by the answer text, and ``tokens`` is the placeholder
        string ``"existing"``.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        Exception: the error of the last attempt, when every attempt fails.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    max_backoff_seconds = 60
    client = OpenAI(api_key="", base_url="")

    for i in range(max_retries):
        # Reset per attempt so a failed stream does not leak partial text.
        reasoning_parts = []
        answer_parts = []
        is_answering = False

        try:
            completion = client.chat.completions.create(
                model='deepseek-r1',
                messages=[
                    {'role': 'user', 'content': prompt}
                ],
                stream = True
            )

            for chunk in completion:
                if not chunk.choices:
                    print("\nUsage:")
                    print(chunk.usage)
                    continue

                delta = chunk.choices[0].delta

                reasoning_piece = getattr(delta, 'reasoning_content', None)
                if reasoning_piece is not None:
                    print(reasoning_piece, end='', flush=True)
                    reasoning_parts.append(reasoning_piece)
                    continue

                answer_piece = getattr(delta, 'content', None)
                if answer_piece is None:
                    # A delta without text: nothing to print or accumulate.
                    continue

                if answer_piece != "" and not is_answering:
                    # First non-empty answer chunk: separate it from the reasoning.
                    print("\n" + "=" * 20 + "" + "=" * 20 + "\n")
                    is_answering = True

                print(answer_piece, end='', flush=True)
                answer_parts.append(answer_piece)
        except Exception as e:
            print(e)
            if i == max_retries - 1:
                raise
            time.sleep(min(2 ** i, max_backoff_seconds) + random.random())
            continue

        prediction = "".join(reasoning_parts) + "".join(answer_parts)
        # Placeholder: token usage is printed above but not returned.
        tokens = "existing"

        return prediction, tokens



def _is_graded_correct(graded_answer):
    """Return True when the grader's reply is 'True', ignoring case and wrapping.

    Surrounding whitespace, quotes, backticks, asterisks and periods are
    tolerated so replies such as ``"True\\n"`` or ``"'True'."`` still count.
    """
    if not isinstance(graded_answer, str):
        return False
    return graded_answer.strip().strip("'\"`*. ").lower() == "true"


def evaluate_all_test():
    """Evaluate ``args.eval_model`` on the combined Theory-of-Mind test sets.

    Reads the model name from the module-level ``args``. Each test example is
    sent to the model under evaluation, and the reply is graded against the
    ground truth by ``Evaluate`` using the ``deepseek-v3`` model. Per-sample
    details, per-data-source counts and overall accuracy are written to
    ``log_<model>_all.log``; a summary is printed to stdout.

    Raises:
        ValueError: if the model name matches none of the supported backends
            ('deepseek', 'o3', 'gpt-4o').
    """
    model_realname = args.eval_model

    # Fail fast on an unsupported model, before any dataset is loaded.
    if 'deepseek' in model_realname:
        def generate(text):
            return getOutput_DeepSeek(text)
    elif 'o3' in model_realname or 'gpt-4o' in model_realname:
        def generate(text):
            return getOutput_OpenAI(text, model_realname)
    else:
        raise ValueError(
            f"Unsupported eval model {model_realname!r}: expected a name containing "
            "'deepseek', 'o3' or 'gpt-4o'"
        )

    logging.basicConfig(
    filename=f"log_{model_realname}_all.log",
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
    )

    category_results = {} # Correct/total counts per data source
    category_percents = {} # Running accuracy per data source
    correctNum = 0
    totalNum = 0

    # Print stuff
    print("\n------------------------")
    print("    EVALUATING      ")
    print("------------------------")
    print(f"EVAL MODEL: {args.eval_model}")
    print("------------------------\n")
    logging.info("------------------------")
    logging.info("    EVALUATING    ")
    logging.info("------------------------")
    logging.info(f"EVAL MODEL: {args.eval_model}")

    def load_hitom(splits):
        # One Hi-ToM test set is the concatenation of several row slices.
        return concatenate_datasets(
            [load_dataset(".../ToM_data/Hi-ToM", split=s) for s in splits]
        )

    def load_json(path, split=None):
        if split is None:
            # Without a split, load_dataset returns a dict keyed by split name.
            return load_dataset("json", data_files=path)['train']
        return load_dataset("json", data_files=path, split=split)

    # (data source name, raw dataset), in the order they are concatenated.
    raw_datasets = [
        ('HiToM_Third', load_hitom(["train[660:680]", "train[760:780]", "train[860:880]"])),
        ('HiToM_Fourth', load_hitom(["train[680:700]", "train[780:800]", "train[880:900]"])),
        ('ToMbench', load_json(".../ToMbench_data/test_combined.json", "train[:431]")),
        ('Socialqa', load_json(".../SocialIqa/socialIWa_v1.4_tst_wDims.json", "train[:120]")),
        ('SimpleToM_mental', load_json(".../SimpleToM/mental-state-qa/test.json", "train[:120]")),
        ('SimpleToM_behavior', load_json(".../SimpleToM/behavior-qa/test.json", "train[:120]")),
        ('SimpleToM_judgment', load_json(".../SimpleToM/judgment-qa/test.json", "train[:120]")),
        ('ToMATO_first', load_json(".../ToMATO/dataset/tomato_first.json", "train[:25]")),
        ('ToMATO_second', load_json(".../ToMATO/dataset/tomato_second.json", "train[:25]")),
        ('OpenToM_attitude', load_json(".../OpenToM/merged_attitude_data.json")),
        ('OpenToM_location_cg_fo', load_json(".../OpenToM/merged_location_cg_fo_data.json")),
        ('OpenToM_location_cg_so', load_json(".../OpenToM/merged_location_cg_so_data.json")),
    ]

    apply_chat_template = True

    mapped_datasets = []
    for source_name, raw in raw_datasets:
        mapped = raw.map(
            function=make_map_fn(source_name, 'test', apply_chat_template=apply_chat_template),
            with_indices=True,
        )
        if source_name.startswith('SimpleToM_'):
            # The raw 'choices' column is dropped from the SimpleToM sets
            # before they are concatenated with the others.
            mapped = mapped.remove_columns('choices')
        if source_name == 'SimpleToM_mental':
            print(mapped)
        mapped_datasets.append(mapped)

    test_dataset = concatenate_datasets(mapped_datasets)

    for idx, sample in enumerate(test_dataset):
        entire_prompt = sample['prompt'][1]['content']
        answer = sample['reward_model']['ground_truth']
        data_source = sample['data_source']

        result, tokens = generate(entire_prompt)

        prompt = f"""\
This is someone's response [{result}] to a context: [{entire_prompt}]:


This is the correct answer:

[{answer}]

Is final answer correct? Output 'True' or 'False' only.
"""

        graded_answer = Evaluate(prompt, 'deepseek-v3')
        correct = _is_graded_correct(graded_answer)
        if correct:
            correctNum += 1
        totalNum += 1

        logging.info(f"Index: {idx}")
        logging.info(f"Entire_Context: {entire_prompt}")
        logging.info(f"Prediction: {result}")
        logging.info(f"Tokens: {tokens}")
        logging.info(f"Label: {answer}")
        logging.info(f"**********Correct**********: {correct}")
        logging.info(f"data_source: {data_source}")

        # Update the running per-data-source tally.
        category_key = "count" + "_" + str(data_source)
        tally = category_results.get(category_key, {"correct": 0, "total": 0})
        if correct:
            tally["correct"] += 1
        tally["total"] += 1
        category_results[category_key] = tally
        category_percents[str(data_source)] = tally["correct"] / tally["total"]

    # Guard against an empty test set instead of dividing by zero.
    accuracy = correctNum / totalNum if totalNum else 0.0
    print(correctNum)
    print(totalNum)
    print(f"Accuracy: {accuracy*100:.3f}%")

    logging.info(f"category_results: {category_results}")
    logging.info(f"category_percents: {category_percents}")
    logging.info(f"Accuracy: {accuracy*100:.3f}%")

    # Print results
    print("\n------------------------")
    print("         RESULTS        ")
    print("------------------------")
    print(f"MODEL: {args.eval_model}")
    print(f"ACCURACY: {accuracy:.2%}")
    print("------------------------\n")


def main():
    """Parse command-line options into the global ``args`` and run the evaluation."""
    # Other functions in this file read the parsed options through this global.
    global args

    cli = argparse.ArgumentParser()
    cli.add_argument('--eval_model', default='deepseek-r1', type=str)
    args = cli.parse_args()

    evaluate_all_test()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()